---
-- 规则变量

local util = require("lib.util")
local string_util = require("lib.string_util")

local _M = {}

_M.VAR_PHASE_IP = "ip"
_M.VAR_PHASE_UA = "ua"
_M.VAR_PHASE_URL = "url"
_M.VAR_PHASE_COOKIE = "cookie"
_M.VAR_PHASE_ARGS = "args"
_M.VAR_PHASE_POST = "post"
_M.VAR_PHASE_HEADER = "header"
_M.VAR_PHASE_FILE = "file"

-- 规则变量名与变量元素的分隔符
--      如: request_headers:user-agent 表示只验证request_headers变量中的 ['user-agent']
local var_sep = ":"

local allow_vars = {
    -- 来访IP
    ip = {
        phase = _M.VAR_PHASE_IP,
        key = "ip",
    },
    args = {
        phase = _M.VAR_PHASE_ARGS,
        key = "args_get",
    },
    post = {
        phase = _M.VAR_PHASE_POST,
        key = "args_post",
    },
    header = {
        phase = _M.VAR_PHASE_HEADER,
        key = "request_headers",
    },
    user_agent = {
        phase = _M.VAR_PHASE_UA,
        key = "ua",
    },
    cookie = {
        phase = _M.VAR_PHASE_COOKIE,
        key = "cookie",
    },
    -- 查询字符串参数
    args_get = {
        phase = _M.VAR_PHASE_ARGS
    },
    -- 查询字符串参数的名称
    args_get_names = {
        phase = _M.VAR_PHASE_ARGS
    },
    -- post请求体的参数
    args_post = {
        phase = _M.VAR_PHASE_POST
    },
    -- post请求体参数的名称
    args_post_names = {
        phase = _M.VAR_PHASE_POST
    },
    -- 包含请求URI的查询字符串部分。 `query_string` 中的值始终是原始提供的，不进行URL解码。
    query_string = {
        phase = _M.VAR_PHASE_ARGS
    },
    -- 包含原始请求体。当检测到 `application / x-www-form-urlencoded` 内容类型时。
    request_body = {
        phase = _M.VAR_PHASE_POST
    },
    -- 所有请求 `cookie` 的集合（仅包含值）
    request_cookies = {
        phase = _M.VAR_PHASE_COOKIE
    },
    -- 所有请求 `cookie` 的名称的集合。
    request_cookies_names = {
        phase = _M.VAR_PHASE_COOKIE
    },
    -- 包含不带查询字符串部分的相对请求URL
    request_filename = {
        phase = _M.VAR_PHASE_URL
    },
    -- 所有请求头的集合
    request_headers = {
        phase = _M.VAR_PHASE_HEADER
    },
    -- 所有请求头的名称的集合
    request_headers_names = {
        phase = _M.VAR_PHASE_HEADER
    },
    -- 包含查询字符串数据在内的完整请求URL（例如，`/index.php?p=X`）。但是，它永远不会包含域名，即使它是在请求行上提供的
    request_uri = {
        phase = _M.VAR_PHASE_URL
    },
    -- 与 "request_uri" 相同，但如果在请求行上提供了域名，则包含域名（例如，`http://www.example.com/index.php?p=X` ）
    request_uri_raw = {
        phase = _M.VAR_PHASE_URL
    },
    -- 包含原始文件名的集合。仅适用于检查通过 `multipart/form-data` 形式上传的请求
    files = {
        phase = _M.VAR_PHASE_POST
    },
    -- 包含用于文件上载的表单字段列表。仅适用于检查通过 `multipart/form-data` 形式上传的请求
    files_names = {
        phase = _M.VAR_PHASE_POST
    },

    -- -- 完整请求行（包括请求方法和HTTP版本信息）
    -- request_line = {
    --     phase = ''
    -- },
    -- -- 请求方法
    -- request_method = {
    --     phase = ''
    -- },
    -- -- 请求协议的版本信息
    -- request_protocol = {
    --     phase = ''
    -- },
}

---验证规则变量var是否在允许列表中
---     - true: 允许的var
---     - false: 非法var
---@param var string    待验证规则变量var(不区分大小写)
---@return boolean
_M.in_allow = function(var)
    local _var = string.lower(var)
    if not allow_vars[_var] then
        return false
    end
    return true
end

---获取所有规则变量var对应的var_key
---@return table
_M.get_all_var_keys = function()
    local result = {}
    for k, v in pairs(allow_vars) do
        if v["key"] then
            result[k] = v["key"]
        else
            result[k] = k
        end
    end
    return result
end

---获取规则变量var对应的var_key
---@return string|nil
---@return string|nil err_info
_M.get_var_key = function(var)
    local _var = string.lower(var)
    if not allow_vars[_var] then
        return nil, var .. " invalid."
    end
    return (allow_vars[_var]["key"] or _var)
end

---获取WAF各检测阶段分类规则的键名称
---     示例: {  ip = {"ip"},
---             url = { "request_uri", "request_uri_raw", "query_string" },
---             header = { "request_headers", "request_headers_names", ... }
---           }
---@return table
_M.get_phase_rule_keys = function()
    local result = {}
    for k, item in pairs(allow_vars) do
        local phase = (item.phase or "others")
        if not result[phase] then
            result[phase] = {}
        end
        local exist = false
        local insert_key = (item["key"] or k)
        for _, v in ipairs(result[phase]) do
            if v == insert_key then
                exist = true
            end
        end
        if not exist then
            table.insert(result[phase], insert_key)
        end
    end
    return result
end

---分离由":"标识的规则变量及变量下标
---     没有":"直接返回原始规则变量
---     比如: "header:content-type"
---         返回  "header", "content-type"
---@param var_origin string 待分离的原始规则变量
---@return string
---@return nil|string
_M.var_son_split = function(var_origin)
    local from, to = string.find(var_origin, var_sep)

    local var_value = nil
    local param_value = nil
    if not from then
        -- 没有 ":" 分隔符
        var_value = var_origin
    else
        -- 有分隔符 ":" 规则变量取前一部分
        var_value = string.sub(var_origin, 1, from - 1)
        param_value = string.sub(var_origin, to + 1)
    end
    return var_value, param_value
end

---验证var合法性(如果var为数组,只要有一个不合法就会返回,后续元素不再验证)
---     返回 true/false, false时信息
---@param var_value string | table
---@return boolean
---@return string | nil
_M.is_valid = function(var_value)
    local var_type = type(var_value)
    if var_type == "nil" then
        return false, "rule:var is empty."
    elseif var_type == "table" then
        if #var_value < 1 then
            return false, "rule:var is empty table."
        end
        for _, v in ipairs(var_value) do
            local res, err = _M.is_valid(v)
            if not res then
                return res, err
            end
        end
    else
        if string.match(var_value, "^%s*$") then
            return false, "rule:var is empty string."
        end

        local value, _ = _M.var_son_split(var_value)
        local v = util.kebabcase_to_underscorecase(value)
        if not _M.in_allow(v) then
            return false, "rule:var invalid."
        end
    end

    return true
end

return _M
